{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "NeurIPS Demo.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/neurips_spotlight_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Az_-0KZ80yF3",
        "cellView": "form"
      },
      "source": [
        "#@title Imports & Utils\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "sns.set_style(style='white')\n",
        "\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "# Download the training data and the pretrained GNN parameters.\n",
        "!wget -O silica_train.npz https://www.dropbox.com/s/3dojk4u4di774ve/silica_train.npz?dl=0\n",
        "# No -O on this download: a later cell opens the file as 'si_gnn.pickle?raw=true'.\n",
        "!wget https://github.com/google/jax-md/blob/master/examples/models/si_gnn.pickle?raw=true\n",
        "\n",
        "import numpy as onp\n",
        "# `device_put` is imported from the top-level `jax` package; the `jax.api` module\n",
        "# is not available in current JAX releases.\n",
        "from jax import device_put\n",
        "\n",
        "box_size = 10.862\n",
        "\n",
        "with open('silica_train.npz', 'rb') as f:\n",
        "  files = onp.load(f)\n",
        "  # arr_3, arr_4 and arr_5 are bound to the positions, energies and forces respectively.\n",
        "  qm_positions, qm_energies, qm_forces = [device_put(x) for x in (files['arr_3'], files['arr_4'], files['arr_5'])]\n",
        "  # Keep only the first 300 entries of each array.\n",
        "  qm_positions = qm_positions[:300]\n",
        "  qm_energies = qm_energies[:300]\n",
        "  qm_forces = qm_forces[:300]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NutOCjdwpneq"
      },
      "source": [
        "## Demo\n",
        "\n",
        "[github.com/google/jax-md](https://github.com/google/jax-md) -> notebooks -> neurips_spotlight_demo.ipynb"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZDlsKyEkdQQP"
      },
      "source": [
        "%pip install jax-md"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rNTqOwVeebG9"
      },
      "source": [
        "Data from a quantum mechanical simulation of silicon."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "alrV-gMVeazj"
      },
      "source": [
        "print(f'Box Size = {box_size}')\n",
        "# Report the array shape of each loaded QM quantity, in the same order as before.\n",
        "for qm_array in (qm_positions, qm_energies, qm_forces):\n",
        "  print(qm_array.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IwT-WPwigGbg"
      },
      "source": [
        "Visualize states inside Colab."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aSxTZZ5kdSL7"
      },
      "source": [
        "from jax_md.colab_tools import renderer\n",
        "\n",
        "# Render the first entry of qm_positions, wrapped in a renderer.Sphere geometry.\n",
        "first_frame_geometry = {'atom': renderer.Sphere(qm_positions[0])}\n",
        "renderer.render(box_size, first_frame_geometry, resolution=[400, 400])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cNuy3OLBf8DD"
      },
      "source": [
        "### Every simulation starts by defining a space."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WBfwqm-3ePkX"
      },
      "source": [
        "from jax_md import space\n",
        "\n",
        "displacement_fn, shift_fn = space.periodic(box_size, wrapped=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MsSGBSO9iJOw"
      },
      "source": [
        "The `displacement_fn` computes the displacement between two points."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jy_2VaQug0cs"
      },
      "source": [
        "displacement_fn(qm_positions[0, 0], qm_positions[0, 3])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ASPtD-riZm5"
      },
      "source": [
        "The `shift_fn` moves a point by a given displacement."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zj-Dbrn9hwhe"
      },
      "source": [
        "import jax.numpy as np\n",
        "\n",
        "# Shift the point qm_positions[0, 0] by the vector (1, 0, 0).\n",
        "unit_x = np.array([1.0, 0.0, 0.0])\n",
        "shift_fn(qm_positions[0, 0], unit_x)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8NB0RfoeimOx"
      },
      "source": [
        "### Load a pretrained Graph Neural Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MV6ZSwvfiwBi"
      },
      "source": [
        "from jax_md import energy\n",
        "\n",
        "# The parameters are loaded from a pretrained pickle below, so the network's init\n",
        "# function is not used. It is bound to a distinct name so that re-running this cell\n",
        "# cannot overwrite the simulation `init_fn` defined later in the notebook.\n",
        "gnn_init_fn, energy_fn = energy.graph_network(displacement_fn, r_cutoff=3.0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1w8aGzlKjUFX"
      },
      "source": [
        "import pickle\n",
        "\n",
        "# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.\n",
        "with open('si_gnn.pickle?raw=true', 'rb') as params_file:\n",
        "  params = pickle.load(params_file)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oYk4h7wHwXm3"
      },
      "source": [
        "# Evaluate the network on the first configuration and compare with the reference value.\n",
        "predicted_energy = energy_fn(params, qm_positions[0])\n",
        "actual_energy = qm_energies[0]\n",
        "print(f'Predicted E = {predicted_energy}')\n",
        "print(f'Actual E = {actual_energy}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xk6FhI4kzGxh"
      },
      "source": [
        "import functools\n",
        "\n",
        "energy_fn = functools.partial(energy_fn, params) "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V6pp5fjzoU5a"
      },
      "source": [
        "from jax import vmap\n",
        "\n",
        "# vmap maps the single-configuration energy function over the leading axis of qm_positions.\n",
        "vectorized_energy_fn = vmap(energy_fn)\n",
        "predicted_energies = vectorized_energy_fn(qm_positions)\n",
        "\n",
        "# Label the axes so the parity plot can be read on its own.\n",
        "plt.plot(qm_energies, predicted_energies, 'o')\n",
        "plt.xlabel('QM energy')\n",
        "plt.ylabel('Predicted energy')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ptv0B1eXxSJP"
      },
      "source": [
        "from jax_md import quantity\n",
        "\n",
        "force_fn = quantity.force(energy_fn)\n",
        "predicted_forces = force_fn(qm_positions[1])\n",
        "\n",
        "# Flatten both force arrays so every component becomes one point in the plot,\n",
        "# and label the axes so the figure can be read on its own.\n",
        "plt.plot(qm_forces[1].reshape((-1,)), \n",
        "         predicted_forces.reshape((-1,)), 'o')\n",
        "plt.xlabel('QM force component')\n",
        "plt.ylabel('Predicted force component')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mxNhgDd8yCAG"
      },
      "source": [
        "### Using the network in a simulation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eJob7sBby3vs"
      },
      "source": [
        "from jax_md.simulate import nvt_nose_hoover\n",
        "\n",
        "# Simulation constants.\n",
        "# NOTE(review): K_B matches the Boltzmann constant in eV/K, so kT presumably corresponds\n",
        "# to 300 K; the units of dt and Si_mass are not stated in this notebook - confirm.\n",
        "K_B = 8.617e-5\n",
        "Si_mass = 2.91086E-3\n",
        "dt = 5e-3\n",
        "kT = K_B * 300\n",
        "\n",
        "init_fn, step_fn = nvt_nose_hoover(energy_fn, shift_fn, dt, kT, tau=1.0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ioU39DeGz7Rp"
      },
      "source": [
        "from jax import jit\n",
        "step_fn = jit(step_fn)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xa9BFd_ByAJM"
      },
      "source": [
        "from jax import random\n",
        "\n",
        "# A fixed PRNG seed is used to build the initial simulation state.\n",
        "key = random.PRNGKey(0)\n",
        "initial_positions = qm_positions[0]\n",
        "state = init_fn(key, initial_positions, Si_mass)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DS6fsZR9y3vu"
      },
      "source": [
        "positions = []\n",
        "\n",
        "# Advance the simulation 5000 steps, recording the positions on every 25th step.\n",
        "for step in range(5000):\n",
        "  state = step_fn(state)\n",
        "  if step % 25 == 0:\n",
        "    positions.append(state.position)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8CUV49dN1Grk"
      },
      "source": [
        "positions = np.stack(positions)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "01u6DaDN09T4"
      },
      "source": [
        "# Render the recorded trajectory: Sphere receives the stacked positions array.\n",
        "trajectory_geometry = {'atom': renderer.Sphere(positions)}\n",
        "renderer.render(box_size, trajectory_geometry, resolution=[400, 400])"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}